{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn \n",
    "import torch.nn.functional as F\n",
    "from torchsummary import summary\n",
    "\n",
    "class RecCNN(nn.Module):\n",
    "    def __init__(self, num_classes):\n",
    "        super().__init__()\n",
    "        self.convnet = nn.Sequential(\n",
    "            nn.Conv2d(1, 64, 3, padding='same'), nn.ReLU(),\n",
    "            nn.MaxPool2d(2, stride=2),\n",
    "            \n",
    "            nn.Conv2d(64, 16, 3, padding='same'), nn.ReLU(),\n",
    "            nn.MaxPool2d(2, stride=2),\n",
    "        )\n",
    "        self.fc1 = nn.Linear(32*32*16, 32)\n",
    "        self.fc2 = nn.Linear(32, num_classes)\n",
    "        # self.softmax = nn.Softmax()\n",
    "    def forward(self, x):\n",
    "        x = torch.flatten(self.convnet(x), 1)\n",
    "        x = self.fc1(x)\n",
    "        x = self.fc2(x)\n",
    "        # c = self.softmax(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random stand-in for one single-channel 128x128 image (batch of 1).\n",
    "x = torch.rand(1, 1, 128, 128)\n",
    "# Untrained network with five output classes.\n",
    "model = RecCNN(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1         [-1, 64, 128, 128]             640\n",
      "              ReLU-2         [-1, 64, 128, 128]               0\n",
      "         MaxPool2d-3           [-1, 64, 64, 64]               0\n",
      "            Conv2d-4           [-1, 16, 64, 64]           9,232\n",
      "              ReLU-5           [-1, 16, 64, 64]               0\n",
      "         MaxPool2d-6           [-1, 16, 32, 32]               0\n",
      "            Linear-7                   [-1, 32]         524,320\n",
      "            Linear-8                    [-1, 5]             165\n",
      "================================================================\n",
      "Total params: 534,357\n",
      "Trainable params: 534,357\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.06\n",
      "Forward/backward pass size (MB): 19.13\n",
      "Params size (MB): 2.04\n",
      "Estimated Total Size (MB): 21.23\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# torchsummary defaults to device='cuda'; the model here lives on the CPU,\n",
    "# so pin the device to avoid an input/weight device mismatch on GPU hosts.\n",
    "summary(model, (1, 128, 128), device='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from thop import profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.\n",
      "[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.\n",
      "[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.\n",
      "[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.\n",
      "[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.\n",
      "MACs (G):  47.710368\n",
      "Params (M):  0.534357\n"
     ]
    }
   ],
   "source": [
    "x = torch.rand((1, 1, 128, 128))\n",
    "model = RecCNN(num_classes=5)\n",
    "macs, params = profile(model, inputs=(x, ))\n",
    "# Both counts are divided by 1000**2, so they are reported in millions (M).\n",
    "# This network needs about 47.7 million MACs, i.e. roughly 0.048 G.\n",
    "print('MACs (M): ', macs/1000**2)\n",
    "print('Params (M): ', params/1000**2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('torch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "78a29cc2c05d3ee8d935820ad86792723c958d8c7f217aee9aa88e38f878a5d1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
